"""
下述训练代码仅提供简单训练的范例，对大规模数据，
例如Places2可以进行数据替换，直接放置在对应cache中即可；
此外configuration.json(~/.cache/modelscope/hub/damo/cv_fft_inpainting_lama/)可以进行自定义修改；

在Places2上使用Adam优化器，初始学习率为1e-3，训练1M iterations。

测试时主要的预处理如下：

Pad：图像高宽镜像填充至8的倍数
Normalize：图像归一化，像素大小由0-255归一化到0-1即可

采用我们提供的PlacesToydataset数据进行finetune后得到的结果FID一般为 30-80 之间（由于我们提供的数据少，而FID的计算依赖大量数据，故此处FID结果偏高且不稳定）


如果你觉得该模型对你有所帮助，请考虑引用以下相关论文：

@article{suvorov2021resolution,
  title={Resolution-robust Large Mask Inpainting with Fourier Convolutions},
  author={Suvorov, Roman and Logacheva, Elizaveta and Mashikhin, Anton and Remizova, Anastasia and Ashukha, Arsenii and Silvestrov, Aleksei and Kong, Naejin and Goka, Harshith and Park, Kiwoong and Lempitsky, Victor},
  journal={arXiv preprint arXiv:2109.07161},
  year={2021}
}
@article{kulshreshtha2022feature,
  title={Feature Refinement to Improve High Resolution Image Inpainting},
  author={Kulshreshtha, Prakhar and Pugh, Brian and Jiddi, Salma},
  journal={arXiv preprint arXiv:2206.13644},
  year={2022}
}
"""
import os
import shutil
import tempfile
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.models.cv.image_inpainting import FFTInpainting
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import ModelFile
from modelscope.utils.test_utils import test_level



# ---------------------------------------------------------------------------
# Fine-tuning example for the FFT (LaMa) image-inpainting model.
# Downloads the pretrained snapshot, loads the PlacesToydataset train/test
# splits, and runs the inpainting trainer end-to-end.
# ---------------------------------------------------------------------------

model_id = 'damo/cv_fft_inpainting_lama'

# Fetch (or reuse from cache) the model snapshot and read its configuration;
# cfg.dataset carries the mask-generation kwargs and per-split output sizes
# consumed below.
cache_path = snapshot_download(model_id)
cfg = Config.from_file(
    os.path.join(cache_path, ModelFile.CONFIGURATION))


def _make_data_cfg(split, out_size, test_mode):
    """Return the loading kwargs for one split of the PlacesToydataset."""
    return ConfigDict(
        name='PlacesToydataset',
        split=split,
        mask_gen_kwargs=cfg.dataset.mask_gen_kwargs,
        out_size=out_size,
        test_mode=test_mode)


def _load_split(data_cfg):
    """Load one dataset split and fail loudly if it came back empty.

    Raises:
        RuntimeError: if the loaded split's config is falsy. An explicit
            raise is used instead of ``assert``, which is silently stripped
            when Python runs with ``-O``.
    """
    dataset = MsDataset.load(
        dataset_name=data_cfg.name,
        split=data_cfg.split,
        mask_gen_kwargs=data_cfg.mask_gen_kwargs,
        out_size=data_cfg.out_size,
        test_mode=data_cfg.test_mode)
    if not next(iter(dataset.config_kwargs['split_config'].values())):
        raise RuntimeError(
            f'Failed to load split "{data_cfg.split}" of {data_cfg.name}')
    return dataset


# Build per-split configs: training uses random masks at train_out_size;
# evaluation uses val_out_size with test_mode enabled.
train_data_cfg = _make_data_cfg(
    split='train', out_size=cfg.dataset.train_out_size, test_mode=False)
test_data_cfg = _make_data_cfg(
    split='test', out_size=cfg.dataset.val_out_size, test_mode=True)

train_dataset = _load_split(train_data_cfg)
test_dataset = _load_split(test_data_cfg)

kwargs = dict(
    model=model_id,
    train_dataset=train_dataset,
    eval_dataset=test_dataset)

# Build the image-inpainting trainer with its default configuration and
# launch training (hyperparameters come from the downloaded
# configuration.json; see the module docstring for customization notes).
trainer = build_trainer(
    name=Trainers.image_inpainting, default_args=kwargs)
trainer.train()